/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

#include "onnxoptimizer/pass.h"

namespace ONNX_NAMESPACE {
namespace optimization {

struct FuseConsecutiveTransposes final : public PredicateBasedPass {
  explicit FuseConsecutiveTransposes()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_consecutive_transposes";
  }

  // returns a vector `ret` such that transposing by `ret` is equivalent
  // to transposing by `t1` and then by `t2`
  std::vector<int64_t> compose_transposes(const std::vector<int64_t>& t1,
                                          const std::vector<int64_t>& t2) {
    ONNX_ASSERT(t1.size() == t2.size());
    std::vector<int64_t> ret;
    ret.reserve(t1.size());
    for (size_t i = 0; i < t1.size(); i++) {
      ONNX_ASSERT(t2[i] < static_cast<int64_t>(t1.size()));
      ONNX_ASSERT(t1[static_cast<size_t>(t2[i])] <
                  static_cast<int64_t>(t1.size()));
      ret.push_back(t1[static_cast<size_t>(t2[i])]);
    }
    return ret;
  }

  bool patternMatchPredicate(Node* node) override {
    return node->kind() == kTranspose &&
           node->input()->node()->kind() == kTranspose;
  }

  bool runTransform(Node* n, Graph&,
                    NodeDestroyType& destroy_current) override {
    auto origInput = n->input();
    if (!n->hasAttribute(kperm) && !origInput->node()->hasAttribute(kperm)) {
      // One special case (two consecutive transposes with no perm,
      // since we do not have the shape information here, we have
      // to eliminate two transpose together.
      if (n->output()->has_sizes()) {
        origInput->node()->input()->setSizes(n->output()->sizes());
      }
      const bool replacing_success =
          tryReplacingAllUsesWith(n, origInput->node()->input()->node());
      if (!replacing_success) {
        return false;
      }
      // only destroy the first transpose and DCE will take care of the another
      destroy_current = NodeDestroyType::DestroyOne;
      return true;
    }
    if (!n->hasAttribute(kperm) || !origInput->node()->hasAttribute(kperm)) {
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }
    n->is_(kperm,
           compose_transposes(origInput->node()->is(kperm), n->is(kperm)));
    n->replaceInput(0, origInput->node()->input());
    if (origInput->uses().size() == 0) {
      origInput->node()->destroy();
    }
    destroy_current = NodeDestroyType::DestroyZero;
    return false;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
